from ensemble_boxes import *
import json, copy
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import numpy as np
import json
import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator
import os
from sko.GA import GA

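'''
    precision: [T, R, K, A, M]  (iou, recall, cls, area range, max dets)
               [10, 101, 3, 4, 3]

    T: the 10 IoU thresholds used by COCO evaluation, from 0.5 to 0.95 in steps of 0.05.
    R: the recall thresholds used by COCO evaluation, from 0 to 1 in steps of 0.01 (one percentage point), 101 values in total.
    K: the number of object categories in the detection task (80 for the COCO dataset; 3 in the example shape above).
    A: the object-size ranges used by the detection task, 4 values in total:
        the first means no size restriction,
        the second is small objects (area < 32^2),
        the third is medium objects (32^2 < area < 96^2),
        the fourth is large objects (area > 96^2).
    M: the maximum number of detections per image; COCO uses 3 values: 1, 10 and 100.
'''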

def _summarize(coco, ap=True, catId=None, iouThr=None, areaRng='all', maxDets=100):
    """Compute a single COCO-style summary number from accumulated results.

    Args:
        coco: Evaluator-like object exposing ``params`` (with ``iouThrs``,
            ``areaRngLbl`` and ``maxDets``) and an ``eval`` mapping holding
            the ``'precision'`` and ``'recall'`` arrays.
        ap: Truthy selects ``eval['precision']``, falsy selects
            ``eval['recall']``. The printed title/tag read "Average
            Precision (AP)" only when ``ap == 1``.
        catId: Integer index on the category axis. Any non-``int`` value
            (including ``None``) keeps every category.
        iouThr: One IoU threshold to keep (matched by exact equality against
            ``params.iouThrs``), or ``None`` to keep all thresholds.
        areaRng: Area-range label looked up in ``params.areaRngLbl``.
        maxDets: Max-detections value looked up in ``params.maxDets``.

    Returns:
        tuple: ``(mean_s, print_string)`` where ``mean_s`` is the mean of the
        selected entries greater than -1 (or ``-1`` when there are none) and
        ``print_string`` is the formatted summary line.
    """
    params = coco.params
    template = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'

    labelled_as_precision = ap == 1
    title = 'Average Precision' if labelled_as_precision else 'Average Recall'
    tag = '(AP)' if labelled_as_precision else '(AR)'

    if iouThr is None:
        iou_text = '{:0.2f}:{:0.2f}'.format(params.iouThrs[0], params.iouThrs[-1])
    else:
        iou_text = '{:0.2f}'.format(iouThr)

    # Positions of the requested area range / max-dets setting on the last
    # two axes (area labels: 'all', 'small', 'medium', 'large').
    area_idx = [pos for pos, label in enumerate(params.areaRngLbl) if label == areaRng]
    dets_idx = [pos for pos, limit in enumerate(params.maxDets) if limit == maxDets]

    # precision is [T x R x K x A x M]; recall is [T x K x A x M].
    metric = coco.eval['precision'] if ap else coco.eval['recall']

    if iouThr is not None:
        # Keep only the rows of the leading IoU axis that match iouThr.
        metric = metric[np.where(iouThr == params.iouThrs)[0]]

    # Category, area and max-dets are the trailing axes of both arrays, so an
    # ellipsis covers the differing number of leading axes.
    if isinstance(catId, int):
        metric = metric[..., catId, area_idx, dets_idx]
    else:
        metric = metric[..., area_idx, dets_idx]

    # Entries equal to -1 mark "no data" and are left out of the average.
    valid = metric[metric > -1]
    mean_s = np.mean(valid) if len(valid) != 0 else -1

    print_string = template.format(title, tag, iou_text, areaRng, maxDets, mean_s)
    return mean_s, print_string



def fusion(fusion_type, josn_list, score, image_info_path, iou_thr, skip_box_thr, sigma,
           output_path='./val_soft_nms.json'):
    """Fuse per-model detection results image by image and write them as JSON.

    Args:
        fusion_type: One of ``'nms'``, ``'soft_nms'``, ``'nmw'``
            (non-maximum weighted) or ``'wbf'`` (weighted boxes fusion).
        josn_list: Paths of the per-model detection JSON files, one per model.
        score: Per-model scores; turned into weights with ``get_weights``.
        image_info_path: Path of the JSON file whose ``'images'`` list gives
            each image's ``id``, ``width`` and ``height``.
        iou_thr: IoU threshold passed to the fusion function.
        skip_box_thr: Score threshold; passed as ``thresh`` to ``soft_nms``
            and as ``skip_box_thr`` to ``nmw``/``wbf`` (unused by ``nms``).
        sigma: Sigma passed to ``soft_nms`` only.
        output_path: Where the fused detections are written. Defaults to the
            path this function has always written to.

    Raises:
        ValueError: If ``fusion_type`` is not one of the supported names.
    """
    supported = ('nms', 'soft_nms', 'nmw', 'wbf')
    if fusion_type not in supported:
        # Fail early instead of hitting an undefined (or stale) `boxes` below.
        raise ValueError(
            'Unknown fusion_type {!r}; expected one of {}'.format(fusion_type, supported))

    ori_weights = get_weights(score)
    imgid_info = get_image_info(image_info_path)
    result_list = [get_bbox_ann(json_path) for json_path in josn_list]

    # Fuse one image at a time.
    new_result = []
    for img_id, (W, H) in imgid_info.items():
        boxes_list = []
        scores_list = []
        labels_list = []
        weights = []
        for model_idx, res in enumerate(result_list):
            if img_id not in res:
                # This model has no detections for the image: leave out both
                # its boxes and its weight.
                continue
            boxes_list.append(res[img_id]['boxes'])
            scores_list.append(res[img_id]['scores'])
            labels_list.append(res[img_id]['labels'])
            # Build the weight list alongside the boxes so each weight stays
            # paired with its own model (deleting by index from a copy shifts
            # the remaining entries).
            if model_idx < len(ori_weights):
                weights.append(ori_weights[model_idx])

        if not boxes_list:
            # No model predicted anything for this image; nothing to fuse.
            continue

        if fusion_type == 'nms':
            boxes, scores, labels = nms(
                boxes_list, scores_list, labels_list,
                weights=weights, iou_thr=iou_thr)
        elif fusion_type == 'soft_nms':
            boxes, scores, labels = soft_nms(
                boxes_list, scores_list, labels_list,
                weights=weights, iou_thr=iou_thr, sigma=sigma, thresh=skip_box_thr)
        elif fusion_type == 'nmw':
            boxes, scores, labels = non_maximum_weighted(
                boxes_list, scores_list, labels_list,
                weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)
        else:  # 'wbf'
            boxes, scores, labels = weighted_boxes_fusion(
                boxes_list, scores_list, labels_list,
                weights=weights, iou_thr=iou_thr, skip_box_thr=skip_box_thr)

        # De-normalise: fused boxes are [x1, y1, x2, y2] as fractions of the
        # image size; the output stores [x, y, width, height] in pixels.
        for i, box in enumerate(boxes):
            new_result.append({
                'image_id': img_id,
                'category_id': int(labels[i]),
                "score": float(scores[i]),
                "bbox": [float(box[0]*W), float(box[1]*H),
                         float((box[2]-box[0])*W), float((box[3]-box[1])*H)]
            })

    with open(output_path, 'w') as json_file:
        json.dump(new_result, json_file)

def get_image_info(image_info_path):
    """Load image sizes from a JSON file with an ``'images'`` list.

    Args:
        image_info_path: Path of a JSON file whose top-level ``'images'``
            entry is a list of dicts with ``'id'``, ``'width'`` and
            ``'height'`` keys.

    Returns:
        dict: Maps each image ``id`` to ``[width, height]``. When an id
        appears more than once, the last entry wins.
    """
    with open(image_info_path, 'r') as handle:
        images = json.load(handle)['images']

    return {entry['id']: [entry['width'], entry['height']] for entry in images}
    
def get_weights(score):
    """Scale the scores so that the largest one becomes 1.

    Args:
        score: Non-empty sequence of numeric scores.

    Returns:
        list: Each score divided by the maximum score, in the input order.
    """
    largest = max(score)
    scaled = []
    for value in score:
        scaled.append(value / largest)
    return scaled

def get_bbox_ann(json_path, image_info_path=None):
    """Load one model's detections and group them by image, normalised.

    Args:
        json_path: Path of a JSON file holding a list of detections, each a
            dict with ``'image_id'``, ``'category_id'``, ``'score'`` and
            ``'bbox'`` (``[x, y, width, height]``).
        image_info_path: Path of the JSON file with the image sizes (read via
            ``get_image_info``). When omitted, the module-level
            ``image_info_path`` variable is used, which is how this function
            found the file before the parameter existed.

    Returns:
        dict: Maps each image id to ``{"boxes": [...], "scores": [...],
        "labels": [...]}``, where every box is ``[x1, y1, x2, y2]`` divided
        by the image width/height.

    Raises:
        ValueError: If no image-info path is given and no module-level
            ``image_info_path`` is defined.
    """
    if image_info_path is None:
        # Legacy behaviour: the path used to be read implicitly from a global
        # that only exists when the file is run as a script.
        image_info_path = globals().get('image_info_path')
        if image_info_path is None:
            raise ValueError(
                'image_info_path must be passed to get_bbox_ann or defined '
                'at module level')

    with open(json_path, 'r') as f:
        bbox_data = json.load(f)

    imgid_info = get_image_info(image_info_path)

    result = {}

    # Group the detections by image id.
    for ann in bbox_data:
        image_id = ann['image_id']
        W, H = imgid_info[image_id]
        bbox = ann['bbox']

        # [x, y, w, h] in pixels -> [x1, y1, x2, y2] as fractions of the image.
        normalised_box = [bbox[0] / W, bbox[1] / H,
                          (bbox[0] + bbox[2]) / W,
                          (bbox[1] + bbox[3]) / H]

        per_image = result.setdefault(image_id, {
            "boxes": [],
            "scores": [],
            "labels": []
        })
        per_image["boxes"].append(normalised_box)
        per_image["scores"].append(ann['score'])
        per_image["labels"].append(ann['category_id'])

    return result

def get_mAP(label_json_path, res_json_path):
    """Evaluate a detection result file against ground truth and return mAP.

    Prints a separator line, runs the COCO bbox evaluation (evaluate +
    accumulate) and reads the summary value with ``_summarize`` using its
    default arguments.

    Args:
        label_json_path: Path of the ground-truth annotation JSON.
        res_json_path: Path of the detection result JSON.

    Returns:
        float: The first value returned by ``_summarize``.
    """
    print('---------'*10)
    ground_truth = COCO(label_json_path)
    detections = ground_truth.loadRes(res_json_path)

    evaluator = COCOeval(ground_truth, detections, 'bbox')
    evaluator.evaluate()
    evaluator.accumulate()

    summary = _summarize(evaluator)
    return float(summary[0])

if __name__ == '__main__':
    # Script entry point: search for the fusion hyper-parameters (per-model
    # scores, IoU threshold, skip-box threshold, sigma) that give the best mAP
    # on the validation set, using a genetic algorithm.
    label_json_path = './val.json'
    res_json_path = './val_soft_nms.json'  # file written by fusion()
    categories_info = {
        1: '1 (rect_eye)',
        2: '2 (sphere_eye)',
        3: '3 (box_eye)',
    }

    josn_list = ['./bbox_0.json',
                './bbox_1.json',
                './bbox_2.json',
                './bbox_3.json']

    result = {
        'nms': {},
        'soft_nms': {},
        'nmw': {},
        'wbf': {}
    }

    fusion_type_list = ['nms', 'soft_nms', 'nmw', 'wbf']
    iou_thr_range = np.arange(0.4, 0.91, 0.02)
    ##################### hyper-parameters ###############
    # score / iou_thr / skip_box_thr / sigma are only used by the disabled
    # grid search at the bottom; cost() takes its own values from the GA.
    score = [2, 1.5, 1, 1]
    iou_thr = 0.5
    skip_box_thr = 0.001
    sigma = 0.1
    fusion_type = 'nms'
    ######################################################

    # Also read as a module-level global by get_bbox_ann().
    image_info_path = './val.json'

    def cost(x):
        """GA objective: fuse with the candidate parameters, return -mAP.

        ``x`` is ``[score_0, score_1, score_2, score_3, iou_thr,
        skip_box_thr, sigma]``, e.g. ``[1, 1, 1, 1, 0.5, 0.001, 0.1]``.
        """
        score = x[:4]
        iou_thr = x[4]
        skip_box_thr = x[5]
        sigma = x[6]
        fusion(fusion_type, josn_list, score, image_info_path, iou_thr, skip_box_thr, sigma)
        mAP = get_mAP(label_json_path, res_json_path)
        # The GA minimises its objective, so negate mAP to maximise it.
        return -mAP

    ga = GA(func=cost, n_dim=7, size_pop=50, max_iter=100, 
            lb=[1, 1, 1, 1, 0.4, 0.0001, 0.1], 
            ub=[4, 4, 4, 4, 0.9, 0.005, 0.5],
            precision=[0.1, 0.1, 0.1, 0.1, 0.02, 0.0002, 0.1])
    res = ga.run()
    best_x, best_y = res
    print('best params:', best_x)
    # Undo the sign flip applied in cost() to report the actual mAP.
    print('best mAP:', -best_y)

    # Disabled alternative: grid search over iou_thr for one fusion type.
    # for fusion_type in fusion_type_list:
    #     iou_map = []
    #     for iou_thr in iou_thr_range:
    #         fusion(fusion_type, josn_list, score, image_info_path, iou_thr, skip_box_thr, sigma)
    #         mAP = get_mAP(label_json_path, res_json_path)
    #         print(iou_thr, " --- ", mAP)
    #         iou_map.append([float(iou_thr), mAP])
    #     break
    # result['nms']['iou_thr'] = np.array(iou_map)
    # print(result['nms']['iou_thr'])
    # max_index = np.argmax(result['nms']['iou_thr'], axis=0)[-1]
    # print('\n iou     mAP')
    # print(result['nms']['iou_thr'][max_index, :])